# -*- coding: utf-8 -*-
# @Time : 2024/5/4 17:22
# @Author : huangjing
# @File : lenet_eval.py
# @Software : PyCharm
import mindspore as ms
from mindvision.dataset import Mnist
from mindvision.classification.models import lenet
# Evaluate a trained LeNet checkpoint on one batch of the MNIST test split
# and print the predicted labels next to the ground-truth labels.

# Load the dataset: MNIST test split in batches of 32.
# NOTE(review): `resize=32` presumably rescales images to 32x32 for LeNet --
# confirm against the mindvision Mnist documentation.
mnist = Mnist("./MNIST_Data/", split="test", batch_size=32, resize=32)
dataset_eval = mnist.run()

# Declare the neural network: 10 output classes, no pretrained weights
# (the weights come from the local checkpoint below).
network = lenet(num_classes=10, pretrained=False)

# Load the checkpoint into the network.
param_dict = ms.load_checkpoint("./lenet/lenet_1-1_1875.ckpt")
param_not_load, ckpt_not_load = ms.load_param_into_net(network, param_dict)
# Surface any mismatch between checkpoint and network instead of silently
# running inference with partially initialised weights.
if param_not_load:
    print(f"Warning: network parameters missing from checkpoint: {param_not_load}")
if ckpt_not_load:
    print(f"Warning: checkpoint entries not loaded into network: {ckpt_not_load}")

# Inference: put the network in evaluation mode before the forward pass.
network.set_train(False)
for data, label in dataset_eval:
    pred = network(data)
    # Index of the largest output along axis 1 (presumably the class axis).
    predicted = pred.argmax(1)
    print(f'Predicted: "{predicted[:]}"\nActual: "{label[:]}"')
    # Only the first batch is shown.
    break
